{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "JAX Colab TPU Test",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/google/jax/blob/master/tests/notebooks/colab_tpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WkadOyTDCAWD",
        "colab_type": "text"
      },
      "source": [
        "# JAX Colab TPU Test\n",
        "\n",
        "This notebook is meant to be run in a [Colab](http://colab.research.google.com) TPU runtime as a basic check for JAX updates."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_tKNrbqqBHwu",
        "colab_type": "code",
        "outputId": "bf0043b0-6f2b-44e4-9822-4f426b3d158e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "import jax\n",
        "import jaxlib\n",
        "\n",
        "!cat /var/colab/hostname\n",
        "print(jax.__version__)\n",
        "print(jaxlib.__version__)"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tpu-s-2dna7uebo6z96\n",
            "0.1.64\n",
            "0.1.45\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DzVStuLobcoG",
        "colab_type": "text"
      },
      "source": [
        "## TPU Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "IXF0_gNCRH08",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "e66e0f60-6e57-44ed-ba6a-2d3fa906b101"
      },
      "source": [
        "import requests\n",
        "import os\n",
        "\n",
        "# TPU_DRIVER_MODE is only assigned after the request below succeeds, so\n",
        "# re-running this cell in the same kernel skips the request.\n",
        "if 'TPU_DRIVER_MODE' not in globals():\n",
        "  # Keep the host part of COLAB_TPU_ADDR (drop its port) and post to port 8475.\n",
        "  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20200416'\n",
        "  resp = requests.post(url)\n",
        "  # Include the URL and status code so a failed request is diagnosable from the cell output.\n",
        "  assert resp.status_code == 200, f\"TPU driver version request to {url} failed with HTTP {resp.status_code}\"\n",
        "  TPU_DRIVER_MODE = 1\n",
        "\n",
        "# The following is required to use TPU Driver as JAX's backend.\n",
        "from jax.config import config\n",
        "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n",
        "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n",
        "print(config.FLAGS.jax_backend_target)"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "grpc://10.69.129.170:8470\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oqEG21rADO1F",
        "colab_type": "text"
      },
      "source": [
        "## Confirm Device"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "8BwzMYhKGQj6",
        "outputId": "d51b7f21-d300-4420-8c5c-483bace8617d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "from jaxlib import tpu_client_extension\n",
        "import jax\n",
        "key = jax.random.PRNGKey(1701)\n",
        "arr = jax.random.normal(key, (1000,))\n",
        "device = arr.device_buffer.device()\n",
        "print(f\"JAX device type: {device}\")\n",
        "assert isinstance(device, tpu_client_extension.TpuDevice), \"unexpected JAX device type\""
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "JAX device type: TPU_0(host=0,(0,0,0,0))\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z0FUY9yUC4k1",
        "colab_type": "text"
      },
      "source": [
        "## Matrix Multiplication"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "eXn8GUl6CG5N",
        "outputId": "9954a064-ef8b-4db3-aad7-85d07b50f678",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "import jax\n",
        "import numpy as np\n",
        "\n",
        "# matrix multiplication on TPU\n",
        "key = jax.random.PRNGKey(0)\n",
        "x = jax.random.normal(key, (3000, 3000))\n",
        "result = jax.numpy.dot(x, x.T).mean()\n",
        "print(result)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "1.021576\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jCyKUn4-DCXn",
        "colab_type": "text"
      },
      "source": [
        "## XLA Compilation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2GOn_HhDPuEn",
        "outputId": "a4384c55-41fb-44be-845d-17b86b152068",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "@jax.jit\n",
        "def selu(x, alpha=1.67, lmbda=1.05):\n",
        "  return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n",
        "x = jax.random.normal(key, (5000,))\n",
        "result = selu(x).block_until_ready()\n",
        "print(result)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[ 0.34676817 -0.7532211   1.7060809  ...  2.120809   -0.42622015\n",
            "  0.13093244]\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}